
Conversation

@TianHao324 TianHao324 commented Jan 19, 2026

Summary

Add NPU support for the embedding kernel.

  • Implements a flattened, grid-stride Triton kernel for the embedding forward/backward passes to improve scalability and reduce launch overhead on Ascend NPUs (a simplified sketch follows below).
  • Uses UB-aware tiling (compute_default_tiling_strategy) and the NPU vector core count to dynamically select the block size and grid size for more stable performance.
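
As a rough illustration of the approach described above, here is a minimal sketch of a flattened, grid-stride embedding forward kernel. The names, signature, and block size are illustrative assumptions, not the exact code in this PR.

```python
# Minimal sketch of a flattened, grid-stride embedding forward kernel.
# Hypothetical names/signature for illustration only; not the PR's exact code.
import triton
import triton.language as tl


@triton.jit
def embedding_fwd_flat_sketch(
    embeddings_ptr,  # [num_embeddings, embedding_dim] weight table
    indices_ptr,     # [n_indices] token ids
    output_ptr,      # [n_indices * embedding_dim] flattened output
    total_elements,  # n_indices * embedding_dim
    embedding_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    # Grid-stride loop: a fixed-size grid covers arbitrarily large inputs by
    # striding over the flattened output in steps of num_programs * BLOCK_SIZE.
    for start in range(pid * BLOCK_SIZE, total_elements, num_programs * BLOCK_SIZE):
        offs = start + tl.arange(0, BLOCK_SIZE)
        mask = offs < total_elements
        row = offs // embedding_dim  # which token
        col = offs % embedding_dim   # which feature within that token's row
        idx = tl.load(indices_ptr + row, mask=mask, other=0)
        val = tl.load(embeddings_ptr + idx * embedding_dim + col, mask=mask, other=0.0)
        tl.store(output_ptr + offs, val, mask=mask)
```

With a kernel of this shape, the grid can be fixed to the NPU vector core count and the block size chosen by the UB-aware tiling helper, so launch overhead stays roughly constant regardless of input size.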

Testing Done

I tested the embedding kernel with the following methods, and all cases passed:

  • python benchmark/scripts/benchmark_embedding.py
  • pytest -v test/transformers/test_embedding.py
  • Hardware Type: Ascend NPU 910B4
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@TianHao324 (Contributor Author):

test_embedding result:
[screenshot of test results]

@TianHao324 (Contributor Author):

Hi @Tcc0403, could you please help me review my code?

@Tcc0403 Tcc0403 (Collaborator) left a comment

It seems the current implementation is quite inefficient. I've left comments about some possible issues it might have.

)


def get_optimal_block_size(total_elements, is_backward: bool):
Collaborator:

what does is_backward do?

Contributor Author:

Sorry, at first I intended to distinguish the forward and backward directions. Later, I realized their logic was quite similar and I forgot to delete it.

Comment on lines 14 to 20
@triton.jit
def embedding_forward_kernel(
embeddings_ptr,
indices_ptr,
output_ptr,
total_elements,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
Collaborator:

I think the original implementation with 2 block sizes for the tile shape is more readable and more efficient.

A persistent grid loop is fine, but the way this kernel loads the embedding seems to be uncoalesced at some point.

Collaborator:

For instance, some dim_idx values will not be consecutive if BLOCK_SIZE is not a multiple of embedding_dim. That makes the second tl.load access different rows within a warp, and the same applies to the last store.

Creating these offsets with a 2D block size is more readable and more efficient, since it avoids the uncoalesced access mentioned above.
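
To make the coalescing point concrete, here is a hypothetical sketch of offsets built from two block sizes (rows x columns), so each tl.load/tl.store touches a contiguous run of columns within a single row. The names and the column-tiling loop are assumptions for illustration, not the PR's actual code.

```python
# Hypothetical 2D-offset construction (illustrative only, not the PR's code).
import triton
import triton.language as tl


@triton.jit
def embedding_fwd_2d_sketch(
    embeddings_ptr,
    indices_ptr,
    output_ptr,
    n_indices,
    embedding_dim,
    BLOCK_ROWS: tl.constexpr,  # tokens handled per program
    BLOCK_COLS: tl.constexpr,  # features handled per inner tile
):
    pid = tl.program_id(0)
    row_offs = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    row_mask = row_offs < n_indices
    idx = tl.load(indices_ptr + row_offs, mask=row_mask, other=0)

    # Walk the embedding dimension in column tiles; the last dimension of each
    # tile is contiguous in memory, so loads and stores remain coalesced.
    for col_start in range(0, embedding_dim, BLOCK_COLS):
        col_offs = col_start + tl.arange(0, BLOCK_COLS)
        col_mask = col_offs < embedding_dim
        mask = row_mask[:, None] & col_mask[None, :]
        src = embeddings_ptr + idx[:, None] * embedding_dim + col_offs[None, :]
        dst = output_ptr + row_offs[:, None] * embedding_dim + col_offs[None, :]
        tl.store(dst, tl.load(src, mask=mask, other=0.0), mask=mask)
```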

Contributor Author:

I have changed it to a 2D block. After testing, it indeed shows much better performance. The issues mentioned below have also been fixed. Could you please review it again?

Comment on lines 110 to 126
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
)
Collaborator:

dtype_size should be embedding.dtype?

Contributor Author:

Modified

block_size = tile_shapes[0][0]
return block_size
else:
return triton.next_power_of_2(total_elements)
Collaborator:

I think the fallback value should be a workable one; triton.next_power_of_2(total_elements) is too large.
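
For reference, the kind of bounded fallback being suggested could look something like the sketch below; the cap value is an assumption for illustration, not a value from the PR.

```python
import triton

# Hypothetical bounded fallback: cap the block size instead of using
# next_power_of_2(total_elements) directly, which can exceed available UB.
MAX_FALLBACK_BLOCK_SIZE = 8192  # assumed cap, not taken from the PR


def fallback_block_size(total_elements: int) -> int:
    return min(triton.next_power_of_2(total_elements), MAX_FALLBACK_BLOCK_SIZE)
```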

Contributor Author:

Modified

embeddings_ptr + embedding_offsets,
mask=final_mask,
other=0.0,
).to(tl.float32)
Collaborator:

Is there any particular consideration for why we need to upcast it?

Contributor Author:

Modified


Tcc0403 commented Jan 20, 2026

Could you attach the benchmark results for reference?


TianHao324 commented Jan 20, 2026

> Could you attach the benchmark results for reference?

Compared to the previous version, performance has improved by 4 to 5 times. However, there is still a significant gap compared to HuggingFace. I also tried the original GPU code (only addressing the UB issue), and its performance was nearly the same (the results are shown below).

[
  {
    "kernel_name": "embedding",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      42.66733932495117,
      43.84379959106445,
      43.834800720214844,
      43.53144836425781,
      43.65476989746094,
      42.79145050048828,
      44.18817138671875,
      44.12928009033203
    ],
    "y_values_20": [
      42.66537094116211,
      43.84306716918945,
      43.83445358276367,
      43.531349182128906,
      43.65372085571289,
      42.7907829284668,
      44.18741989135742,
      44.12871551513672
    ],
    "y_values_80": [
      42.669307708740234,
      43.84453201293945,
      43.835147857666016,
      43.531551361083984,
      43.655818939208984,
      42.792118072509766,
      44.18891906738281,
      44.129844665527344
    ],
    "timestamp": "2026-01-20 10:33:22",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "embedding",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "V",
    "x_label": "embedding dimension",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.08077999949455261,
      0.091559998691082,
      0.1134599968791008,
      0.14830000698566437,
      0.1863200068473816,
      0.21172000467777252,
      0.22543999552726746,
      0.2385600060224533
    ],
    "y_values_20": [
      0.08038800209760666,
      0.09114000201225281,
      0.11287999898195267,
      0.14771999418735504,
      0.18585199117660522,
      0.21121999621391296,
      0.22499999403953552,
      0.23792000114917755
    ],
    "y_values_80": [
      0.08191999793052673,
      0.09239999949932098,
      0.11416800320148468,
      0.14903999865055084,
      0.18700000643730164,
      0.21240000426769257,
      0.22592000663280487,
      0.23929999768733978
    ],
    "timestamp": "2026-01-20 10:33:35",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 32, \"T\": 512, \"D\": 768, \"dtype\": \"torch.float32\"}",
    "liger_version": "0.0.0"
  },

Implementation using the original GPU code:
[screenshot of benchmark results]

@Tcc0403 Tcc0403 (Collaborator) left a comment

I'm fine with merging this PR since it's an experimental operator and isn’t used in any patching path. That said, we should probably open a performance issue for this kernel and track it for future improvements.

@TianHao324 (Contributor Author):

> I'm fine with merging this PR since it's an experimental operator and isn’t used in any patching path. That said, we should probably open a performance issue for this kernel and track it for future improvements.

You're right. In fact, we do have plans to improve the performance. For now, we need to first support these operators on the NPU and then explore ways to optimize performance as much as possible.


Tcc0403 commented Jan 21, 2026

Could you open an issue with benchmarking results so we can track this performance problem and allow future contributors to work on it?

@TianHao324 (Contributor Author):

> Could you open an issue with benchmarking results so we can track this performance problem and allow future contributors to work on it?

Sure! #1036


Tcc0403 commented Jan 21, 2026

Thank you!

@Tcc0403 Tcc0403 merged commit 57e98d3 into linkedin:main Jan 21, 2026
3 of 7 checks passed